{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y2RgpN7yLf9J"
      },
      "source": [
        "# Using Elastic and OpenELM to Prototype Apple-Inspired AI\n",
        "\n",
        "This is the supporting material for [this blog post](https://search-labs.elastic.co/search-labs/blog/using-openelm-models).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4y68v93nNXTH",
        "outputId": "a2d6f13d-c0ec-40af-f623-a043fe3654e6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting elasticsearch\n",
            "  Downloading elasticsearch-8.15.1-py3-none-any.whl.metadata (8.7 kB)\n",
            "Collecting elastic-transport<9,>=8.13 (from elasticsearch)\n",
            "  Downloading elastic_transport-8.15.0-py3-none-any.whl.metadata (3.6 kB)\n",
            "Requirement already satisfied: urllib3<3,>=1.26.2 in /usr/local/lib/python3.10/dist-packages (from elastic-transport<9,>=8.13->elasticsearch) (2.2.3)\n",
            "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from elastic-transport<9,>=8.13->elasticsearch) (2024.8.30)\n",
            "Downloading elasticsearch-8.15.1-py3-none-any.whl (524 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m524.6/524.6 kB\u001b[0m \u001b[31m31.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading elastic_transport-8.15.0-py3-none-any.whl (64 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: elastic-transport, elasticsearch\n",
            "Successfully installed elastic-transport-8.15.0 elasticsearch-8.15.1\n"
          ]
        }
      ],
      "source": [
        "%pip install elasticsearch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 65,
      "metadata": {
        "id": "IBWlj7OJLhmN"
      },
      "outputs": [],
      "source": [
        "from elasticsearch import Elasticsearch, helpers, exceptions, ConnectionTimeout\n",
        "from getpass import getpass"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 43,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GEoWlRH2Ma9d",
        "outputId": "42a9fcf1-5f4c-4bdb-8fd0-7e68bd24dc07"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Elastic Cloud ID: ··········\n",
            "Elastic Api Key: ··········\n",
            "Huggingface Token: ··········\n"
          ]
        }
      ],
      "source": [
        "# Model checkpoint to run: https://huggingface.co/apple/OpenELM\n",
        "MODEL = \"apple/OpenELM-3B-Instruct\"\n",
        "\n",
        "# Credentials are read with getpass rather than written into the notebook.\n",
        "# Cloud ID: https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#finding-your-cloud-id\n",
        "ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID: \")\n",
        "# API key: https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#creating-an-api-key\n",
        "ELASTIC_API_KEY = getpass(\"Elastic Api Key: \")\n",
        "# Hugging Face access token: https://huggingface.co/docs/hub/en/security-tokens\n",
        "HUGGINGFACE_TOKEN = getpass(\"Huggingface Token: \")\n",
        "\n",
        "# Elasticsearch client for the Elastic Cloud deployment identified above.\n",
        "# For local development: hosts=[\"http://localhost:9200\"]\n",
        "client = Elasticsearch(cloud_id=ELASTIC_CLOUD_ID, api_key=ELASTIC_API_KEY)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vs0LMPgwLoFW"
      },
      "source": [
        "## 2. Deploy the OpenELM Model\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5bWSYVrOLo9D",
        "outputId": "84460f5e-39ab-4cb2-ef52-351aa461cfa6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Cloning into 'OpenELM'...\n",
            "remote: Enumerating objects: 12, done.\u001b[K\n",
            "remote: Counting objects: 100% (11/11), done.\u001b[K\n",
            "remote: Compressing objects: 100% (11/11), done.\u001b[K\n",
            "remote: Total 12 (delta 4), reused 0 (delta 0), pack-reused 1 (from 1)\u001b[K\n",
            "Unpacking objects: 100% (12/12), 8.28 KiB | 2.07 MiB/s, done.\n"
          ]
        }
      ],
      "source": [
        "!git clone https://huggingface.co/apple/OpenELM"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "CDHtF8RaO62t"
      },
      "outputs": [],
      "source": [
        "prompt = \"Once upon a time there was\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fMPhVXXoPsXY"
      },
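      "source": [
        "Run the `generate_openelm.py` script from the cloned repository on the prompt above to check that the model loads and generates text.\n"
      ]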
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GqwDnJSSLth8",
        "outputId": "02c73b2b-1ea6-4371-a3e6-992e5f89b8a2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:root:inference device is not set, using cuda:0, NVIDIA A100-SXM4-40GB\n",
            "Loading checkpoint shards: 100% 2/2 [00:01<00:00,  1.25it/s]\n",
            "2024-09-30 04:46:39.147179: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
            "2024-09-30 04:46:39.163451: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
            "2024-09-30 04:46:39.181587: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
            "2024-09-30 04:46:39.186881: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
            "2024-09-30 04:46:39.200600: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
            "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
            "2024-09-30 04:46:40.313149: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
            "\n",
            "\n",
            "\u001b[1m Prompt + Generated Output\u001b[0m\n",
            "\n",
            "Once upon a time there was a little girl named Rosie. Rosie loved to play dress-up, and her favorite costume was a princess dress. Her mother dressed Rosie in the princess dress every day, but Rosie wanted to wear it on special occasions, too.\n",
            "\n",
            "One day Rosie's mother told Rosie that Prince Charming would arrive at their house later that evening. Rosie couldn't wait! Prince Charming would surely ask Rosie to marry him, and she would wear her beautiful princess dress for their wedding. Rosie ran upstairs to get dressed.\n",
            "\n",
            "When Prince Charming arrived, Rosie wore her princess dress and tiara. Her mother helped her with her makeup, and Rosie practiced her curtsy. Prince Charming smiled and kissed Rosie's hand. Rosie hoped Prince Charming would ask her to marry him, but Prince Charming shook Rosie's hand instead. Prince Charming explained that Rosie's father had died when Rosie was very young, and Prince Charming wanted Rosie to choose her own husband. Rosie felt sad, but Prince Charming promised her\n",
            "\n",
            "\n",
            "Generation took\u001b[1m\u001b[92m 9.71 \u001b[0mseconds.\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# The script path is relative to the working directory, where the OpenELM repository was cloned.\n",
        "!python OpenELM/generate_openelm.py --model '{MODEL}' --hf_access_token '{HUGGINGFACE_TOKEN}' --prompt '{prompt}' --generate_kwargs repetition_penalty=1.2 prompt_lookup_num_tokens=10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iF4n3etCLwVs"
      },
      "source": [
        "## 3. Index Data in Elasticsearch\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 391
        },
        "id": "S71mSikTQWWT",
        "outputId": "8d468e94-2e07-4271-f263-5cfe7b717aea"
      },
      "outputs": [],
      "source": [
        "# Create the \"my-elser-model\" sparse-embedding inference endpoint (ELSER service).\n",
        "try:\n",
        "    # Uncomment to delete an existing endpoint before creating it again:\n",
        "    # client.options(request_timeout=5).inference.delete(inference_id=\"my-elser-model\")\n",
        "    client.options(request_timeout=5).inference.put(\n",
        "        task_type=\"sparse_embedding\",\n",
        "        inference_id=\"my-elser-model\",\n",
        "        body={\n",
        "            \"service\": \"elser\",\n",
        "            \"service_settings\": {\"num_allocations\": 1, \"num_threads\": 1},\n",
        "        },\n",
        "    )\n",
        "except ConnectionTimeout:\n",
        "    # Deliberately non-fatal: the request is capped at 5 seconds and the endpoint\n",
        "    # may still be starting (TODO confirm), so report the timeout instead of hiding it.\n",
        "    print(\n",
        "        \"Request to create inference endpoint 'my-elser-model' timed out after 5s; \"\n",
        "        \"it may still be deploying. Check that it is ready before creating the index.\"\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bs8GUJb2L0vQ",
        "outputId": "8394a345-4d40-4021-b8ed-af96497c0bfc"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'mobile-assistant'})"
            ]
          },
          "execution_count": 45,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "index_name = \"mobile-assistant\"\n",
        "\n",
        "# Both text fields share the English analyzer; `description` is additionally\n",
        "# copied into `semantic_field`, which uses the \"my-elser-model\" inference endpoint.\n",
        "english_text = {\"type\": \"text\", \"analyzer\": \"english\"}\n",
        "index_body = {\n",
        "    \"mappings\": {\n",
        "        \"properties\": {\n",
        "            \"title\": dict(english_text),\n",
        "            \"description\": {**english_text, \"copy_to\": \"semantic_field\"},\n",
        "            \"semantic_field\": {\n",
        "                \"type\": \"semantic_text\",\n",
        "                \"inference_id\": \"my-elser-model\",\n",
        "            },\n",
        "        }\n",
        "    }\n",
        "}\n",
        "\n",
        "# Start from a clean slate: drop any existing index (a missing one is ignored), then create it.\n",
        "client.indices.delete(index=index_name, ignore_unavailable=True)\n",
        "client.indices.create(index=index_name, body=index_body)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {
        "id": "UTsn9uYxL3oN"
      },
      "outputs": [],
      "source": [
        "# Sample mailbox as (id, subject, body) tuples.\n",
        "sample_emails = [\n",
        "    (\n",
        "        \"email1\",\n",
        "        \"Team Meeting Agenda\",\n",
        "        \"Hello team, Let's discuss our project progress in tomorrow's meeting. Please prepare your updates. Best regards, Manager\",\n",
        "    ),\n",
        "    (\n",
        "        \"email2\",\n",
        "        \"Client Proposal Draft\",\n",
        "        \"Hi, I've attached the draft of our client proposal. Could you review it and provide feedback? Thanks, Colleague\",\n",
        "    ),\n",
        "    (\n",
        "        \"email3\",\n",
        "        \"Weekly Newsletter\",\n",
        "        \"This week in tech: AI advancements, new smartphone releases, and cybersecurity updates. Read more on our website!\",\n",
        "    ),\n",
        "    (\n",
        "        \"email4\",\n",
        "        \"Urgent: Project Deadline Update\",\n",
        "        \"Dear team, Due to recent developments, we need to move up our project deadline. The new submission date is next Friday. Please adjust your schedules accordingly and let me know if you foresee any issues. We'll discuss this in detail during our next team meeting. Best regards, Project Manager\",\n",
        "    ),\n",
        "    (\n",
        "        \"email5\",\n",
        "        \"Invitation: Company Summer Picnic\",\n",
        "        \"Hello everyone, We're excited to announce our annual company summer picnic! It will be held on Saturday, July 15th, at Sunny Park. There will be food, games, and activities for all ages. Please RSVP by replying to this email with the number of guests you'll be bringing. We look forward to seeing you there! Best, HR Team\",\n",
        "    ),\n",
        "]\n",
        "\n",
        "# One bulk action per email, all targeting `index_name`.\n",
        "documents = [\n",
        "    {\"_index\": index_name, \"_id\": email_id, \"title\": subject, \"description\": body}\n",
        "    for email_id, subject, body in sample_emails\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nkQVCidRL7K1",
        "outputId": "efdfd4a7-d292-4733-cc0e-0c948abb0b4e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Successfully indexed 5 documents\n"
          ]
        }
      ],
      "source": [
        "# Send every action in `documents` through the bulk helper; with raise_on_error=False\n",
        "# the per-document errors are returned in `errors` and printed below.\n",
        "success, errors = helpers.bulk(client, documents, raise_on_error=False)\n",
        "print(f\"Successfully indexed {success} documents\")\n",
        "if errors:\n",
        "    print(\"Errors encountered during bulk indexing:\")\n",
        "    print(*errors, sep=\"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a6hBZKpIL-7k"
      },
      "source": [
        "## 4. Asking Questions\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 176,
      "metadata": {
        "id": "bJUv9I1DMBDr"
      },
      "outputs": [],
      "source": [
        "# https://github.com/riccardomusmeci/mlx-llm/blob/main/src/mlx_llm/prompt/openelm.py\n",
        "def build_prompt(question, elasticsearch_documents):\n",
        "    docs_text = \"\\n\".join(\n",
        "        [\n",
        "            f\"Subject: {doc['title']}\\nBody: {doc['description']}\"\n",
        "            for doc in elasticsearch_documents\n",
        "        ]\n",
        "    )\n",
        "\n",
        "    prompt = f\"\"\"\n",
        "    You are a helpful virtual assistant.\n",
        "    You must classify an email in one of the following categories:\n",
        "    ['SPAM', 'Marketing', 'Project']\n",
        "    Do not make up emails or email categories.\n",
        "    EMAIL:\n",
        "    {docs_text}\n",
        "    Category:\n",
        "    \"\"\"\n",
        "\n",
        "    return prompt\n",
        "\n",
        "\n",
        "def retrieve_documents(question):\n",
        "    search_body = {\n",
        "        \"size\": 1,\n",
        "        \"query\": {\"semantic\": {\"query\": question, \"field\": \"semantic_field\"}},\n",
        "    }\n",
        "    response = client.search(index=index_name, body=search_body)\n",
        "    return [hit[\"_source\"] for hit in response[\"hits\"][\"hits\"]]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 177,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 122
        },
        "id": "zVyZj-txW8_A",
        "outputId": "3dce8016-3b33-49a9-e8b8-cc3e208a44a7"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "\"\\n    You are a helpful virtual assistant.\\n    You must classify an email in one of the following categories:\\n    ['SPAM', 'Marketing', 'Project']\\n    Do not make up emails or email categories.\\n    EMAIL:\\n    Subject: Urgent: Project Deadline Update\\nBody: Dear team, Due to recent developments, we need to move up our project deadline. The new submission date is next Friday. Please adjust your schedules accordingly and let me know if you foresee any issues. We'll discuss this in detail during our next team meeting. Best regards, Project Manager\\n    Category:\\n    \""
            ]
          },
          "execution_count": 177,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "question = \"how is the project going?\"\n",
        "# Keep the search hits under their own name: `documents` still holds the bulk\n",
        "# actions defined earlier, so re-running the indexing cell keeps working.\n",
        "retrieved_documents = retrieve_documents(question)\n",
        "prompt = build_prompt(question, retrieved_documents)\n",
        "prompt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 118,
      "metadata": {
        "id": "evIb9FeUZLeT"
      },
      "outputs": [],
      "source": [
        "from OpenELM.generate_openelm import generate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 178,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "861d746d54784cc8ba0a1ba8d87b37c8",
            "26292b1de03041ad9b00c4d9dc82db5e",
            "a0fac1f5214b41a2b32507a5ad87c132",
            "af8487814d8c4ba889068a40cc17e24a",
            "2800c358468a4a12aa37baaf74de163e",
            "5e058355f5b74f0b8ac00ea3c086a648",
            "04c3698e16b54fa8b13110450b108ccd",
            "bcac1d4b7adc4c5d924ba6a1429ee0e9",
            "bb7fe61992a445949f54f7275a473271",
            "5257db973c3a4c86a1349f0e7a8a6c79",
            "ea4b6dedd2534a8baa0f992dd2877ef8"
          ]
        },
        "id": "iJRFRSTBZSoF",
        "outputId": "32f50164-479c-47e7-c0a6-794060578c06"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:root:inference device is not set, using cuda:0, NVIDIA A100-SXM4-40GB\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "861d746d54784cc8ba0a1ba8d87b37c8",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "-----GENERATION TIME-----\n",
            "\u001b[92m 31.87 \u001b[0m\n",
            "-----RESPONSE-----\n",
            "\n",
            "    You are a helpful virtual assistant.\n",
            "    You must classify an email in one of the following categories:\n",
            "    ['SPAM', 'Marketing', 'Project']\n",
            "    Do not make up emails or email categories.\n",
            "    EMAIL:\n",
            "    Subject: Urgent: Project Deadline Update\n",
            "Body: Dear team, Due to recent developments, we need to move up our project deadline. The new submission date is next Friday. Please adjust your schedules accordingly and let me know if you foresee any issues. We'll discuss this in detail during our next team meeting. Best regards, Project Manager\n",
            "    Category:\n",
            "    1. SPAM: Email is spammy and does not meet the given criteria.\n",
            "![alt text](./Assets/spam.PNG)\n",
            "    2. Marketing: Email is promotional in nature and does not meet the given criteria.\n",
            "![alt text](./Assets/marketing.PNG)\n",
            "    3. Project: Email meets the given criteria.\n",
            "<p align=\"center\">\n",
            "    <img src=\"./Assets/corrected_output.png\" alt=\"Project\" height=\"200\"/>\n",
            "</p>\n",
            "\n",
            "---\n",
            "\n",
            "### Task Recap\n",
            "1. Implement a Python function `classify_email` that accepts an email as input and categorizes it according to the given rules.\n",
            "2. Modify your solution from Lesson 1 so that it accepts a dictionary containing email categories instead of hardcoding them.\n",
            "3. Test your function using Python's `doctest` module.\n",
            "4. Submit your solution (including tests) as a Python file named `classify_emails.py`.\n",
            "\n",
            "---\n",
            "\n",
            "## Solution\n",
            "\n",
            "#### Python Function `classify_email`\n",
            "\n",
            "```python\n",
            "import re\n",
            "import pymemcpy\n",
            "import doctest\n",
            "import itertools\n",
            "import collections\n",
            "from typing import List, Dict, Tuple\n",
            "\n",
            "def classify_email(email: str, categories: Dict[str, int]) -> Tuple[bool, List[str]]:\n",
            "    \"\"\"Classify an email according to the given categories.\n",
            "\n",
            "    Parameters:\n",
            "        email (str): Email string.\n",
            "        categories (dict[str, int]): Dictionary mapping email categories to integer weights.\n",
            "\n",
            "    Returns:\n",
            "        bool: True if email was classified correctly, False otherwise.\n",
            "        List[str]: List of email categories assigned to the email.\n",
            "    \"\"\"\n",
            "\n",
            "    email_pattern = re.compile(r'^(?:[a-zA-Z0-9._%+-]+|[^@\\s]+)@(?:[a-zA-Z0-9-]+\\.)+' + r'\\w+' + r'\\w*' + r'\\.\\w+' + r'\\w*$')\n",
            "\n",
            "    email_match = email_pattern.fullmatch(email)\n",
            "    if email_match:\n",
            "        email_domain = email_match.groups()[-2]\n",
            "        email_domain_parts = email_domain.split('.')\n",
            "        email_domain_root = '.'.join(email_domain_parts[:-1])\n",
            "        email_domain_root_parts = email_domain_root.split('\\\\')\n",
            "        email_domain_root_path = '/'.join(email_domain_root_parts[:-1])\n",
            "\n",
            "        email_host_pattern = r'\\b' + email_domain_root_path + r'\\b'\n",
            "        email_host_pattern += r'\\b' + '.' + r'\\w+' + r'\\w*$'\n",
            "        email_host_pattern += r'\\b' + ':' + r'\\d+' + r'\\b'\n",
            "\n",
            "        email_host_match = email_host_pattern.fullmatch(email_domain)\n",
            "        if email_host_match:\n",
            "            email_host = email_host_match.groups()[-2]\n",
            "            email_host_domain_parts = email_host.split('.')\n",
            "            email_host_domain_root = '.'.join(email_host_domain_parts[:-1])\n",
            "            email_host_domain_root_parts = email_host_domain_root.split('\\\\')\n",
            "            email_host_domain_root_path = '/'.join(email_host_domain_root_parts[:-1])\n",
            "\n",
            "            email_domain_root_path_regex = r'\\b' + email_domain_root_path + r'\\b'\n",
            "            email_domain_root_path_regex += r'\\b' + ':' + r'\\d+' + r'\\b'\n",
            "\n",
            "            email_domain_root_\n"
          ]
        }
      ],
      "source": [
        "# Run OpenELM on the prompt built above, then print the timing and the generated text.\n",
        "output_text, generation_time = generate(\n",
        "    prompt=prompt,\n",
        "    model=MODEL,\n",
        "    hf_access_token=HUGGINGFACE_TOKEN,\n",
        "    generate_kwargs={\"repetition_penalty\": 1.2, \"prompt_lookup_num_tokens\": 10},\n",
        ")\n",
        "print(\n",
        "    \"-----GENERATION TIME-----\",\n",
        "    f\"\\033[92m {round(generation_time, 2)} \\033[0m\",  # green ANSI colour, then reset\n",
        "    \"-----RESPONSE-----\",\n",
        "    output_text,\n",
        "    sep=\"\\n\",\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "04c3698e16b54fa8b13110450b108ccd": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "26292b1de03041ad9b00c4d9dc82db5e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5e058355f5b74f0b8ac00ea3c086a648",
            "placeholder": "​",
            "style": "IPY_MODEL_04c3698e16b54fa8b13110450b108ccd",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "2800c358468a4a12aa37baaf74de163e": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5257db973c3a4c86a1349f0e7a8a6c79": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5e058355f5b74f0b8ac00ea3c086a648": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "861d746d54784cc8ba0a1ba8d87b37c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_26292b1de03041ad9b00c4d9dc82db5e",
              "IPY_MODEL_a0fac1f5214b41a2b32507a5ad87c132",
              "IPY_MODEL_af8487814d8c4ba889068a40cc17e24a"
            ],
            "layout": "IPY_MODEL_2800c358468a4a12aa37baaf74de163e"
          }
        },
        "a0fac1f5214b41a2b32507a5ad87c132": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_bcac1d4b7adc4c5d924ba6a1429ee0e9",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_bb7fe61992a445949f54f7275a473271",
            "value": 2
          }
        },
        "af8487814d8c4ba889068a40cc17e24a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5257db973c3a4c86a1349f0e7a8a6c79",
            "placeholder": "​",
            "style": "IPY_MODEL_ea4b6dedd2534a8baa0f992dd2877ef8",
            "value": " 2/2 [00:01&lt;00:00,  1.11it/s]"
          }
        },
        "bb7fe61992a445949f54f7275a473271": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "bcac1d4b7adc4c5d924ba6a1429ee0e9": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ea4b6dedd2534a8baa0f992dd2877ef8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
